import numpy as np
from scipy.optimize import fsolve
from util import undersampled

def CORSA(design_mean, design_var, rho, design_used, n0):
    """Return the index of the next design (alternative) to sample.

    The decision proceeds in three stages:

    1. Warm-up: while any design has fewer than ``n0`` samples, return the
       least-sampled design.
    2. Catch-up: if ``undersampled`` (from ``util``) reports any design
       indices, return the least-sampled among them.
    3. Allocation: otherwise build a target allocation ``omega`` by scanning
       the intervals between consecutive sorted boundary values from
       ``calculate_C_values`` (with 0 prepended), keep the candidate with the
       smallest objective ``g_val``, and return the design minimising
       ``design_used - total_samples * omega`` (the largest shortfall
       relative to the target).

    Args:
        design_mean: per-design sample means; the design with the largest
            mean is treated as the current best.
        design_var: per-design sample variances.
        rho: correlation parameter, used through ``rho**2`` and
            ``1 - rho**2`` (presumably |rho| < 1 -- confirm with callers).
        design_used: per-design sample counts.
        n0: minimum number of samples each design must receive before
            stages 2 and 3 apply.

    Returns:
        Index of the design to sample next.
    """
    total_sample = np.sum(design_used)
    undersample = undersampled(len(design_mean), total_sample, design_used)

    # Stage 1: warm-up.
    if np.min(design_used) < n0:
        return np.argmin(design_used)

    # Stage 2: catch-up among the design indices flagged by `undersampled`.
    if len(undersample) > 0:
        # Must be a list, not a generator: np.array(<generator>) creates a
        # 0-d object array whose argmin is always 0.
        extracted_values = np.array([design_used[i] for i in undersample])
        return undersample[np.argmin(extracted_values)]

    # Stage 3: interval search for the target allocation.
    best_id = np.argmax(design_mean)
    C = calculate_C_values(design_mean, design_var, rho)
    C_sorted = np.concatenate(([0], np.sort(C)))
    K = len(design_mean)
    g_opt = np.inf
    opt_omega = None

    for C_lower, C_upper in zip(C_sorted[:-1], C_sorted[1:]):
        # Non-best designs whose boundary value has been reached go to K1,
        # the rest to K2.
        K1 = [a for a in range(K) if C[a] <= C_lower and a != best_id]
        K2 = [a for a in range(K) if C[a] > C_lower and a != best_id]

        left_end = feasible_region(K2, design_mean, design_var, rho, C_lower)
        if left_end > C_upper:
            continue  # no admissible point in this interval
        if left_end > C_lower:
            # Step past the feasibility threshold by a fixed 0.1 offset.
            C_lower = left_end + 1e-1
        left_grad = evaluate_gradient(C_lower + (1e-2), design_mean, design_var, rho, K1, K2, best_id)
        right_grad = evaluate_gradient(C_upper - (1e-2), design_mean, design_var, rho, K1, K2, best_id)
        if left_grad < 0 and right_grad > 0:
            # Gradient changes sign inside the interval: solve the interior
            # optimality conditions.
            omega, g_val = zero_point(best_id, design_mean, design_var, rho, K1, K2)
        else:
            omega, g_val = cal_opt_ratio(best_id, design_mean, design_var, rho, K1, K2, C_lower)
        if g_val < g_opt:
            g_opt = g_val
            opt_omega = omega

    if opt_omega is None:
        # No interval produced a comparable candidate (e.g. every g_val was
        # NaN or +inf). Fall back to the least-sampled design instead of
        # failing on arithmetic with None.
        return np.argmin(design_used)

    return np.argmin(design_used - total_sample * opt_omega)

def feasible_region(K2, design_mean, design_var, rho, lb):
    """Return the smallest admissible x given the K2 designs and a lower bound.

    For each design ``a`` in ``K2`` the threshold
    ``2 * var_best * (1 - rho**2) / (mu_best - mu_a)**2`` is computed, where
    the best design is the argmax of ``design_mean``. The result is the
    largest of ``lb`` and all thresholds; it is ``lb`` when ``K2`` is empty.
    """
    top = np.argmax(design_mean)
    mean_top = design_mean[top]
    var_top = design_var[top]

    candidates = [lb]
    candidates.extend(
        2 * var_top * (1 - rho**2) / (mean_top - design_mean[arm]) ** 2
        for arm in K2
    )
    return max(candidates)

def evaluate_gradient(x, design_mean, design_var, rho, K1, K2, best_id):
    """Return dg/dx for the objective g(x) computed by ``evaluate_g_val``.

    g(x) = x + sum_a A_a x / (B_a x - D_a), so each non-best design
    contributes d/dx[A x / (B x - D)] = -A D / (B x - D)**2, with
    B = (mu_best - mu_a)**2 and r = sqrt(var_a / var_best):

    * a in K1: A = 2 var_a (1 - rho**2),  D = 2 var_best (rho r - 1)**2
    * a in K2: A = 2 var_a (rho/r - 1)**2, D = 2 var_best (1 - rho**2)

    Designs in neither ``K1`` nor ``K2`` contribute nothing.

    Args:
        x: point at which to evaluate the derivative.
        design_mean, design_var: per-design means and variances.
        rho: correlation parameter.
        K1, K2: index lists of non-best designs (as built in CORSA they
            partition all designs except ``best_id``).
        best_id: index of the best design.

    Returns:
        The derivative value as a float.
    """
    k1_set = set(K1)
    k2_set = set(K2)
    mu1 = design_mean[best_id]
    var1 = design_var[best_id]
    g_prime = 1.0

    for a in range(len(design_mean)):
        if a == best_id:
            continue
        r = np.sqrt(design_var[a] / var1)
        if a in k1_set:
            term = ((4 * var1 * design_var[a] * (1 - rho ** 2) * (rho * r - 1) ** 2) /
                    ((x * (mu1 - design_mean[a]) ** 2 - 2 * var1 * (rho * r - 1) ** 2) ** 2))
        elif a in k2_set:
            # A * D contains (1 - rho**2), not (1 - rho)**2. This matches the
            # K2 balance condition used by the fsolve system.
            term = ((4 * var1 * design_var[a] * (1 - rho ** 2) * (rho / r - 1) ** 2) /
                    ((x * (mu1 - design_mean[a]) ** 2 - 2 * var1 * (1 - rho ** 2)) ** 2))
        else:
            # Never reuse a stale term from a previous design.
            continue
        g_prime -= term

    return g_prime

def evaluate_g_val(x, design_mean, design_var, rho, K1, K2, best_id):
    """Evaluate the scalar objective g(x) that CORSA compares across intervals.

    g(x) = x + sum_a x * w_a(x), where for each non-best design ``a`` with
    gap = mu_best - mu_a and r = sqrt(var_a / var_best):

    * a in K1: w_a = 2 var_a (1 - rho**2) / (x gap**2 - 2 var_best (rho r - 1)**2)
    * a in K2: w_a = 2 var_a (rho/r - 1)**2 / (x gap**2 - 2 var_best (1 - rho**2))

    Designs in neither ``K1`` nor ``K2`` contribute nothing.

    Args:
        x: evaluation point.
        design_mean, design_var: per-design means and variances.
        rho: correlation parameter.
        K1, K2: index lists of non-best designs.
        best_id: index of the best design.

    Returns:
        g(x) as a float.
    """
    k1_set = set(K1)
    k2_set = set(K2)
    mu1 = design_mean[best_id]
    var1 = design_var[best_id]
    g_val = x

    for a in range(len(design_mean)):
        if a == best_id:
            continue
        r = np.sqrt(design_var[a] / var1)
        if a in k1_set:
            term = ((2 * design_var[a] * (1 - rho ** 2) * x) /
                    (x * (mu1 - design_mean[a]) ** 2 - 2 * var1 * (rho * r - 1) ** 2))
        elif a in k2_set:
            term = ((2 * design_var[a] * (rho / r - 1) ** 2 * x) /
                    (x * (mu1 - design_mean[a]) ** 2 - 2 * var1 * (1 - rho ** 2)))
        else:
            # Never reuse a stale term from a previous design.
            continue
        g_val += term

    return g_val


def calculate_C_values(design_mean, design_var, rho):
    """Return the per-design boundary values used to split designs into K1/K2.

    For a non-best design ``a`` the value is
    ``2 * (var_best + var_a - 2 * rho * sqrt(var_best * var_a)) / (mu_best - mu_a)**2``.
    The best design (argmax of ``design_mean``) gets ``np.inf``. The result
    is a float ndarray with one entry per design.
    """
    top = np.argmax(design_mean)
    mean_top = design_mean[top]
    var_top = design_var[top]

    def boundary(arm):
        if arm == top:
            return np.inf
        spread = var_top + design_var[arm] - 2 * rho * np.sqrt(var_top * design_var[arm])
        gap_sq = (mean_top - design_mean[arm]) ** 2
        return 2 * spread / gap_sq

    return np.array([boundary(arm) for arm in range(len(design_mean))], dtype=float)

def zero_point(best_id, design_mean, design_var, rho, K1, K2):
    """Solve the interior optimality system with ``scipy.optimize.fsolve``.

    The unknowns are the allocation ``omega`` (one entry per design) and a
    common rate ``Y``. The residuals come from ``equations``, which
    determines the best design itself via ``argmax(design_mean)``, so
    ``best_id`` is expected to match it.

    Returns:
        omega: solved allocation, rescaled to sum to 1.
        g_val: ``1 / Y``. This is the objective value on the same scale as the
            ``g_val`` returned by ``cal_opt_ratio``.

    NOTE(review): fsolve's convergence status is not checked here.
    """
    K = len(design_mean)
    initial_omega = np.ones(K) / K
    initial_common_value = 1.0
    initial_guess = np.append(initial_omega, initial_common_value)
    solution = fsolve(equations, initial_guess, args=(design_mean, design_var, rho, K1, K2))
    omega = solution[:K]
    omega = omega / np.sum(omega)
    Y = solution[-1]
    # With unnormalised weights w (w_best = 1) every pairwise rate equals 1/x.
    # The rate at the normalised allocation is therefore 1 / (x * sum(w)),
    # i.e. 1 / g. Here Y is that normalised rate, because the system enforces
    # sum(omega) == 1, so g = 1 / Y. The former omega[best_id] / Y equals x,
    # not g(x).
    g_val = 1.0 / Y
    return omega, g_val

def equations(vars, design_mean, design_var, rho, K1, K2):
    """Residual vector passed to ``fsolve`` by ``zero_point``.

    ``vars`` holds ``[omega_0, ..., omega_{K-1}, Y]``. The returned list
    contains, in order:

    * for each design in ``K1`` then ``K2``: its pairwise rate against the
      best design (argmax of ``design_mean``) minus ``Y``;
    * the balance condition ``sqrt(sum_K1 + sum_K2) - omega_best``;
    * the normalisation ``sum(omega) - 1``.
    """
    n_designs = design_mean.shape[0]
    top = np.argmax(design_mean)
    top_mean = np.max(design_mean)
    top_var = design_var[top]
    weights = vars[:n_designs]
    rate = vars[n_designs]
    w_top = weights[top]
    one_minus_rho_sq = 1 - rho**2

    residuals = []
    for arm in K1:
        ratio = np.sqrt(design_var[arm] / top_var)
        gap_sq = (design_mean[arm] - top_mean)**2
        spread = (top_var * (rho * ratio - 1)**2 / w_top) + (design_var[arm] * one_minus_rho_sq / weights[arm])
        residuals.append(gap_sq / (2 * spread) - rate)

    for arm in K2:
        ratio = np.sqrt(design_var[arm] / top_var)
        gap_sq = (design_mean[arm] - top_mean)**2
        spread = (design_var[arm] * (rho / ratio - 1)**2 / weights[arm]) + (top_var * one_minus_rho_sq / w_top)
        residuals.append(gap_sq / (2 * spread) - rate)

    k1_part = np.sum([
        top_var * (rho * np.sqrt(design_var[arm] / top_var) - 1)**2
        / (design_var[arm] * one_minus_rho_sq) * weights[arm]**2
        for arm in K1
    ])
    k2_part = np.sum([
        top_var * one_minus_rho_sq
        / (design_var[arm] * (rho / np.sqrt(design_var[arm] / top_var) - 1)**2) * weights[arm]**2
        for arm in K2
    ])
    residuals.append(np.sqrt(k1_part + k2_part) - w_top)
    residuals.append(np.sum(weights) - 1)

    return residuals

def cal_opt_ratio(best_id, design_mean, design_var, rho, K1, K2, C_lower):
    """Return the normalised allocation and objective at x = ``C_lower``.

    The best design gets unnormalised weight 1. A design ``a`` in ``K1`` gets
    ``2 var_a (1 - rho**2) / (x gap**2 - 2 var_best (rho r - 1)**2)``, and one
    in ``K2`` gets ``2 var_a (rho/r - 1)**2 / (x gap**2 - 2 var_best (1 - rho**2))``,
    where gap = mu_best - mu_a and r = sqrt(var_a / var_best). Designs in
    neither list keep weight 1.

    Returns:
        (weights / weights.sum(), evaluate_g_val(x, ...))
    """
    x = C_lower
    g_val = evaluate_g_val(x, design_mean, design_var, rho, K1, K2, best_id)

    mu_best = design_mean[best_id]
    var_best = design_var[best_id]
    weights = np.ones(len(design_mean))
    for arm in range(len(design_mean)):
        if arm == best_id:
            continue
        ratio = np.sqrt(design_var[arm] / var_best)
        gap_sq = (mu_best - design_mean[arm]) ** 2
        if arm in K1:
            weights[arm] = ((2 * design_var[arm] * (1 - rho**2)) /
                            (x * gap_sq - 2 * var_best * (rho * ratio - 1)**2))
        elif arm in K2:
            weights[arm] = ((2 * design_var[arm] * (rho / ratio - 1) ** 2) /
                            (x * gap_sq - 2 * var_best * (1 - rho ** 2)))

    return weights / np.sum(weights), g_val



